在洛谷过审了:https://www.luogu.com.cn/article/tl32jtl5

首先不难发现:

ANS=i=0n1(i+1)×P(i)=i=0n1P(xi)ANS = \sum_{i = 0}^{n - 1} (i + 1) \times P(i) = \sum_{i = 0}^{n - 1} P(x \geq i)

其中:

P(xk)=B1×B2××Bk×k!S×(S1)×(Sk+1)P(x \geq k) = \frac{B_1 \times B_2 \times \dots \times B_k \times k!}{S \times (S - 1) \times \dots (S - k + 1)}

这里,S=AiS=\sum A_i

分母很好理解,就是总共的方案数,然后 BB 就对应了我们选的颜色的分别的袜子的数量。

所以这个式子我们要想计算,就需要计算出原数组这 NN 个元素中选 kk 个求积的和。考虑 DP,dpi,jdp_{i,j} 表示前 ii 个元素中选了 jj 个,那转移很简单:

dpi,j=dpi1,j1×Ai+dpi1,jdp_{i,j} = dp_{i - 1, j - 1} \times A_i + dp_{i - 1, j}

其中 dp0,0dp_{0,0} 初始化为 11

但这样时间复杂度就炸了。观察发现我们 DP 的过程,其实我们就是在每次对原本的这个式子乘上 (Ai×x+1)(A_i \times x + 1) 这个多项式,所以我们 DP 的过程其实就等价于求:

i=1NAi\prod_{i = 1}^{N} A_i

这个式子的各项系数,直接使用分治 NTT 即可解决,时间复杂度 O(Nlog2N)\mathrm{O}(N \log^2 N)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 998244353;
const int G = 3;
int fastpow(int a, int b, int mod)
{
int res = 1;
while (b) {
if (b & 1) {
res = (res * a) % mod;
}
a = (a * a) % mod;
b >>= 1;
}
return res;
}
int rev[1 << 22];
void change(vector<int> &a, int len)
{
for (int i = 0; i < len; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
}
for (int i = 0; i < len; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
}
void ntt(vector<int> &a, int len, int x)
{
change(a, len);
for (int h = 2; h <= len; h <<= 1) {
int omega = fastpow(G, (mod - 1) / h, mod);
if (x == -1) {
omega = fastpow(omega, mod - 2, mod);
}
for (int i = 0; i < len; i += h) {
int w = 1;
for (int j = i; j < i + h / 2; j++) {
int u = a[j];
int v = a[j + h / 2] * w % mod;
a[j] = (u + v) % mod;
a[j + h / 2] = (u - v + mod) % mod;
w = (w * omega) % mod;
}
}
}
if (x == -1) {
int inv = fastpow(len, mod - 2, mod);
for (int i = 0; i < len; i++) {
a[i] = (a[i] * inv) % mod;
}
}
}
vector<int> convo(vector<int> a, vector<int> b)
{
if (a.empty() || b.empty()) {
return {0};
}
int m = 1;
while (m < a.size() + b.size() - 1) {
m <<= 1;
}
a.resize(m);
b.resize(m);
ntt(a, m, 1);
ntt(b, m, 1);
for (int i = 0; i < m; i++) {
a[i] = (a[i] * b[i]) % mod;
}
ntt(a, m, -1);
return a;
}
vector<int> solve(vector<int>& a, int l, int r)
{
if (l > r) {
return {1};
}
if (l == r) {
return {1, a[l]};
}
int mid = (l + r) / 2;
return convo(solve(a, l, mid), solve(a, mid + 1, r));
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
vector<int> a(n);
int cnt = 0;
for (int i = 0; i < n; i++) {
cin >> a[i];
cnt += a[i];
}
vector<int> res = solve(a, 0, n - 1);
int ans = 1, s1 = 1, s2 = 1;
for (int i = 1; i <= n && i < res.size(); i++) {
s1 = (s1 * i) % mod;
s2 = s2 * (cnt - i + 1) % mod;
ans = (ans + (res[i] * s1) % mod * fastpow(s2, mod - 2, mod)) % mod;
}
cout << (ans + mod) % mod << '\n';
return 0;
}

AC记录